import torch
import torch.nn as nn
import torch.optim as optim
import layers.nf_nn as fnn
from generic.data_util import get_nf_input, ICEHOCKEY_ACTIONS


def build_maf(agent, act='relu'):
    """Build a MAF flow model plus its Adam optimizer and attach both to ``agent``.

    The flow is ``agent.maf_num_blocks`` repetitions of
    ``[autoregressive layer, fnn.BatchNormFlow, fnn.Reverse]`` wrapped in
    ``fnn.FlowSequential``. Every ``nn.Linear`` found inside the model gets an
    orthogonal weight init and a zero bias. The model is then moved to
    ``agent.device``.

    Args:
        agent: object providing ``maf_flow_type``, ``maf_num_blocks``,
            ``maf_num_inputs``, ``maf_num_hidden``, ``maf_num_cond_inputs``,
            ``device`` and ``maf_lr``. On success ``agent.maf_model`` and
            ``agent.maf_optim`` are set.
        act: activation name forwarded to ``fnn.MADE``. It is used only when
            ``maf_flow_type == 'maf'``; the ``'maf-split'`` variant always uses
            ``s_act='tanh'`` and ``t_act='relu'``.

    Raises:
        ValueError: if ``agent.maf_flow_type`` is neither ``'maf'`` nor
            ``'maf-split'``.
    """
    flow_type = agent.maf_flow_type
    if flow_type not in ('maf', 'maf-split'):
        # Fail fast: an unknown type would otherwise yield an empty flow and
        # only surface later as an opaque "empty parameter list" error from
        # optim.Adam.
        raise ValueError(
            "Unknown maf_flow_type {!r}; expected 'maf' or 'maf-split'".format(flow_type))

    def _autoregressive_layer():
        # One masked autoregressive layer of the requested flavour.
        if flow_type == 'maf':
            return fnn.MADE(num_inputs=agent.maf_num_inputs,
                            num_hidden=agent.maf_num_hidden,
                            num_cond_inputs=agent.maf_num_cond_inputs,
                            act=act)
        return fnn.MADESplit(num_inputs=agent.maf_num_inputs,
                             num_hidden=agent.maf_num_hidden,
                             num_cond_inputs=agent.maf_num_cond_inputs,
                             s_act='tanh', t_act='relu')

    modules = []
    for _ in range(agent.maf_num_blocks):
        modules += [
            _autoregressive_layer(),
            fnn.BatchNormFlow(agent.maf_num_inputs),
            fnn.Reverse(agent.maf_num_inputs),
        ]
    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            # nn.Linear always registers `bias`; it is None when bias=False.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    model.to(agent.device)
    agent.maf_model = model
    agent.maf_optim = optim.Adam(model.parameters(), lr=agent.maf_lr, weight_decay=1e-6)


# def get_nf_input(state_action, trace, apply_history, sanity_check_msg=None):
#     batch_size = len(state_action)
#     if apply_history:
#         tgt_data = torch.reshape(torch.stack(state_action)[:, :, :], shape=(batch_size, -1))
#     else:
#         tgt_data = []
#         for i in range(batch_size):
#             # if sanity_check_msg is None:
#             #     tgt_data.append(state_action[i][trace[i] - 1, :])
#             # elif 'location' in sanity_check_msg and 'ha' in sanity_check_msg:
#             #     state_location = state_action[i][trace[i] - 1, :2]
#             #     state_home_away = state_action[i][trace[i] - 1, 9:11]
#             #     action = state_action[i][trace[i] - 1, 12:]
#             #     tgt_data.append(torch.cat([state_location, state_home_away, action], dim=0))
#             # else:
#             #     raise ValueError("Unknown sanity_check_msg".format(sanity_check_msg))
#             tgt_data.append(state_action[i][trace[i] - 1, :])
#         tgt_data = torch.stack(tgt_data)
#     return tgt_data


def update_maf(agent, batch, sanity_check_msg, batch_values_cond=None):
    """Run one optimisation step of ``agent.maf_model`` on ``batch``.

    Flow inputs and conditioning inputs are produced by ``get_nf_input``. The
    training loss is the negated batch mean of the first value returned by
    ``agent.maf_model.log_probs``; it is back-propagated and
    ``agent.maf_optim`` takes one step.

    Args:
        agent: holder of ``maf_model``, ``maf_optim`` and
            ``maf_apply_history`` (also forwarded to ``get_nf_input``).
        batch: object exposing ``state_action`` and ``trace``.
        sanity_check_msg: forwarded unchanged to ``get_nf_input``.
        batch_values_cond: optional conditioning values, forwarded to
            ``get_nf_input`` as ``values_cond``.

    Returns:
        ``(loss, log_prob)``: scalar tensors holding the mean training loss and
        the batch mean of the second ``log_probs`` output. Both are detached
        from the autograd graph so callers can log or accumulate them
        (``.item()``, ``.cpu().numpy()``) without keeping the graph alive.
    """
    # NOTE(review): the model's train/eval mode is not set here; if
    # fnn.BatchNormFlow depends on it, the caller is presumably responsible.
    tgt_data, tgt_cond = get_nf_input(
        agent=agent,
        state_action=batch.state_action,
        values_cond=batch_values_cond,
        trace=batch.trace,
        apply_history=agent.maf_apply_history,
        sanity_check_msg=sanity_check_msg
    )

    agent.maf_optim.zero_grad()
    m_loss, log_prob = agent.maf_model.log_probs(inputs=tgt_data, cond_inputs=tgt_cond)
    loss = -m_loss.mean()
    # Only needed for reporting; detach so it never holds on to the graph.
    mean_log_prob = log_prob.mean().detach()
    loss.backward()
    agent.maf_optim.step()

    return loss.detach(), mean_log_prob


def validate_maf(agent, batch, sanity_check_msg, batch_values_cond=None):
    """Score ``batch`` with ``agent.maf_model`` without tracking gradients.

    Inputs are built exactly as in training via ``get_nf_input``. Both outputs
    of ``agent.maf_model.log_probs`` have their dimension 1 squeezed away and
    are returned as NumPy arrays on the CPU. The first output is negated so
    that it reads as a per-sample loss.

    Args:
        agent: holder of ``maf_model`` and ``maf_apply_history`` (also
            forwarded to ``get_nf_input``).
        batch: object exposing ``state_action`` and ``trace``.
        sanity_check_msg: forwarded unchanged to ``get_nf_input``.
        batch_values_cond: optional conditioning values, forwarded to
            ``get_nf_input`` as ``values_cond``.

    Returns:
        ``(loss, log_prob)`` NumPy arrays, one entry per sample (assumes the
        ``log_probs`` outputs have a size-1 dimension 1 — TODO confirm).
    """
    # NOTE(review): the model is not switched to eval mode here. If
    # fnn.BatchNormFlow behaves differently in training mode, the caller
    # should call agent.maf_model.eval() first — confirm.
    flow_inputs, flow_cond = get_nf_input(
        agent=agent,
        state_action=batch.state_action,
        values_cond=batch_values_cond,
        trace=batch.trace,
        apply_history=agent.maf_apply_history,
        sanity_check_msg=sanity_check_msg,
    )

    with torch.no_grad():
        first_out, second_out = agent.maf_model.log_probs(
            inputs=flow_inputs, cond_inputs=flow_cond)

    def _as_numpy(tensor):
        # Drop the trailing singleton axis and move to host memory.
        return tensor.squeeze(1).detach().cpu().numpy()

    return _as_numpy(-first_out), _as_numpy(second_out)
